{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "13fd4510-107e-4eb4-b3af-c91da321c3bd",
   "metadata": {},
   "source": [
    "# Floating point management in Concrete"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "de110128-c4cc-42a3-a19e-aaba6bc9008d",
   "metadata": {},
   "source": [
    "In this tutorial, we explain how to handle floating-point values in Concrete circuits.\n",
    "\n",
    "As explained in the documentation, TFHE operations are limited to integers. In most cases, however, this is not a real limitation, since floating-point operations can be turned into integer operations. This is what we study in this tutorial."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0a8b3dc3-1629-4647-9a29-d757313162b9",
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "from concrete import fhe\n",
    "from time import time\n",
    "\n",
    "from numpy.random import randint\n",
    "from numpy.random import rand\n",
    "from numpy import round"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "bf3b2a7d-6b7e-4c05-b54f-c23c2bdffccb",
   "metadata": {},
   "source": [
    "## Starting with an integer circuit"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6b146f03-2b38-4aa7-a0b3-7cd640c44b88",
   "metadata": {},
   "source": [
    "Let's start with a very simple circuit that works directly on integers. The example is inspired by the one in the README."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "8af17025-e184-4238-b579-a68a454cd5a7",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Compilation and test look good, with FHE execution time of about 0.00 seconds per inference\n"
     ]
    }
   ],
   "source": [
    "nb_bits = 4\n",
    "length_inputset = 100\n",
    "nb_test_samples = 16\n",
    "\n",
    "\n",
    "@fhe.compiler({\"x\": \"encrypted\", \"y\": \"encrypted\"})\n",
    "def add_integers(x, y):\n",
    "    return x + y\n",
    "\n",
    "\n",
    "# Compile\n",
    "inputset = [(randint(2**nb_bits), randint(2**nb_bits)) for _ in range(length_inputset)]\n",
    "circuit = add_integers.compile(inputset)\n",
    "\n",
    "# Check\n",
    "time_begin = time()\n",
    "\n",
    "for _ in range(nb_test_samples):\n",
    "    random_sample = (randint(2**nb_bits), randint(2**nb_bits))\n",
    "    assert circuit.encrypt_run_decrypt(*random_sample) == add_integers(*random_sample)\n",
    "\n",
    "print(\n",
    "    \"Compilation and test look good, with FHE execution time of about \"\n",
    "    f\"{(time() - time_begin) / nb_test_samples:.2f} seconds per inference\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b798da0d-2f5e-469c-b263-4ad4b9b8949a",
   "metadata": {},
   "source": [
    "Here, we have defined a function `add_integers`, which takes two encrypted inputs and adds them. We compiled this function using an inputset of random inputs of `nb_bits = 4` bits. Finally, we check the FHE execution by comparing the result of `encrypt_run_decrypt` (which conveniently packs encryption, FHE run and decryption together, for testing purposes) with the clear execution `add_integers(*random_sample)`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3eec2673-500a-4eff-8efb-8e1f63294881",
   "metadata": {},
   "source": [
    "Let's also look at the MLIR of the circuit."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "13708d3e-edc6-4235-8838-72e4fbaae5d3",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module {\n",
      "  func.func @main(%arg0: !FHE.eint<5>, %arg1: !FHE.eint<5>) -> !FHE.eint<5> {\n",
      "    %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<5>, !FHE.eint<5>) -> !FHE.eint<5>\n",
      "    return %0 : !FHE.eint<5>\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(circuit.mlir)"
   ]
  },
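  {
   "cell_type": "markdown",
   "id": "ff805dec-57b4-422e-b35a-e05529148bf8",
   "metadata": {},
   "source": [
    "Here we see that inputs and outputs are treated as 5-bit integers: the sum of two 4-bit values can reach 30, which needs 5 bits, and in this MLIR the inputs have the same width as the output."
   ]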
  },
  {
   "cell_type": "markdown",
   "id": "1f62fa64-d0ca-4d48-8f66-fa6797d6611e",
   "metadata": {},
   "source": [
    "## Trying to use it with floats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f28fa70f-b41a-4617-b68e-8248995e87d8",
   "metadata": {},
   "source": [
    "Now, let's try to use this circuit with floats. As you can imagine, it's not going to work."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "ab241720-ac18-4147-b99f-ae905c4eeec8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Expected argument 0 to be EncryptedScalar<uint5> but it's EncryptedScalar<float64>\n"
     ]
    }
   ],
   "source": [
    "random_sample = (1.5, 2.42)\n",
    "try:\n",
    "    circuit.encrypt_run_decrypt(*random_sample)\n",
    "except Exception as err:\n",
    "    print(err)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "da4ba411-5f0c-41bd-ad47-283ed797f4ef",
   "metadata": {},
   "source": [
    "It raises an error `ValueError: Expected argument 0 to be EncryptedScalar<uint5> but it's EncryptedScalar<float64>`. Let's see now how to deal with this situation."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5c19943d-ec4c-4630-b19a-de16c4c1811e",
   "metadata": {},
   "source": [
    "## Creating a circuit for floats"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "30310f38-4a14-4f10-8ad0-cc828b124fb2",
   "metadata": {},
   "source": [
    "What we are going to do is:\n",
    "- choose a scaling factor\n",
    "- multiply the floats by this scaling factor\n",
    "- round to integers\n",
    "- encrypt\n",
    "- run the FHE computation over integers, as usual\n",
    "- decrypt the result\n",
    "- unscale"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d9c61142-a047-4c46-a91a-441640be9538",
   "metadata": {},
   "source": [
    "We take `scaling_factor = 100`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26239c84-882a-4b43-b5f1-8adca56313e5",
   "metadata": {},
   "source": [
    "Here, we chose a scaling factor that is a power of 10 so that the numbers are easier to read, but it could be any value, including a float."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "c357758a-d380-4e43-95fc-9b07b3fec61e",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Showing an example\n",
      "random_sample=(1.583812144515164, 0.38770109380525164)\n",
      "scaled_sample=(158, 39)\n",
      "scaled_result=197\n",
      "\n",
      "FHE:   1.97\n",
      "clear: 1.9715132383204157\n",
      "bounded error: 0.02\n",
      "\n",
      "Compilation and test look good, with FHE execution time of about 0.00 seconds per inference\n"
     ]
    }
   ],
   "source": [
    "max_value_for_floats = 2.5\n",
    "scaling_factor = 100\n",
    "\n",
    "\n",
    "@fhe.compiler({\"x\": \"encrypted\", \"y\": \"encrypted\"})\n",
    "def add_integers(x, y):\n",
    "    return x + y\n",
    "\n",
    "\n",
    "def are_almost_the_same(a, b, threshold):\n",
    "    abs_diff = abs(b - a)\n",
    "    assert abs_diff <= threshold, f\"Too far {a=} {b=} {abs_diff=} > {threshold=}\"\n",
    "\n",
    "\n",
    "# Compile\n",
    "inputset = [\n",
    "    (\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "    )\n",
    "    for _ in range(length_inputset)\n",
    "]\n",
    "circuit = add_integers.compile(inputset)\n",
    "\n",
    "# Check\n",
    "verbose = True\n",
    "time_begin = time()\n",
    "\n",
    "for _ in range(nb_test_samples):\n",
    "    # Take a random float input\n",
    "    random_sample = (rand() * max_value_for_floats, rand() * max_value_for_floats)\n",
    "    if verbose:\n",
    "        print(\"Showing an example\")\n",
    "        print(f\"{random_sample=}\")\n",
    "\n",
    "    # Scale it and round\n",
    "    scaled_sample = (\n",
    "        round(random_sample[0] * scaling_factor).astype(np.uint32),\n",
    "        round(random_sample[1] * scaling_factor).astype(np.uint32),\n",
    "    )\n",
    "    if verbose:\n",
    "        print(f\"{scaled_sample=}\")\n",
    "\n",
    "    # Encrypt\n",
    "    encrypted_scaled_sample = circuit.encrypt(*scaled_sample)\n",
    "\n",
    "    # Computations in FHE\n",
    "    encrypted_scaled_result = circuit.run(*encrypted_scaled_sample)\n",
    "\n",
    "    # Decrypt\n",
    "    scaled_result = circuit.decrypt(encrypted_scaled_result)\n",
    "    if verbose:\n",
    "        print(f\"{scaled_result=}\")\n",
    "\n",
    "    # Unscale\n",
    "    result = scaled_result * 1.0 / scaling_factor\n",
    "\n",
    "    bounded_error = 2.0 / scaling_factor\n",
    "    clear_result = random_sample[0] + random_sample[1]\n",
    "\n",
    "    if verbose:\n",
    "        print()\n",
    "        print(f\"FHE:   {result}\")\n",
    "        print(f\"clear: {clear_result}\")\n",
    "        print(f\"bounded error: {bounded_error}\")\n",
    "\n",
    "    are_almost_the_same(result, clear_result, bounded_error)\n",
    "\n",
    "    verbose = False\n",
    "\n",
    "print(\n",
    "    \"\\nCompilation and test look good, with FHE execution time of about \"\n",
    "    f\"{(time() - time_begin) / nb_test_samples:.2f} seconds per inference\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "68c26f6b-20bb-40e5-a343-9f4892c7d86b",
   "metadata": {},
   "source": [
    "Here, our testing loop checks that the results match the expected ones, within a tolerance of `bounded_error = 2. / scaling_factor`. We also print one example:\n",
    "- `random_sample` is a pair of floats\n",
    "- `scaled_sample` is the same pair multiplied by the scaling factor and rounded to integers\n",
    "- `scaled_result` is the result of the FHE computation, over integers\n",
    "- at the end, we see the FHE result and the clear result. They are close but not identical, due to rounding errors.\n",
    "\n",
    "To get more precise computations, one needs a larger scaling factor. However, keep in mind that this affects performance, since the integers become larger."
   ]
  },
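  {
   "cell_type": "markdown",
   "id": "addition-error-bound-sketch",
   "metadata": {},
   "source": [
    "As a rough sanity check on this tolerance, here is a short derivation of the rounding error for the addition circuit. The symbols $s$ (for `scaling_factor`), $\\tilde{x}$ and $\\tilde{y}$ (for the scaled and rounded inputs) are notation introduced only for this sketch.\n",
    "\n",
    "Rounding to the nearest integer moves each scaled input by at most one half:\n",
    "\n",
    "$$\\tilde{x} = \\mathrm{round}(s\\,x), \\qquad |\\tilde{x} - s\\,x| \\le \\tfrac{1}{2}, \\qquad |\\tilde{y} - s\\,y| \\le \\tfrac{1}{2}.$$\n",
    "\n",
    "The integer addition itself is exact, so after unscaling:\n",
    "\n",
    "$$\\left| \\frac{\\tilde{x} + \\tilde{y}}{s} - (x + y) \\right| \\le \\frac{|\\tilde{x} - s\\,x| + |\\tilde{y} - s\\,y|}{s} \\le \\frac{1}{s}.$$\n",
    "\n",
    "With $s = 100$, the error should therefore stay at or below about $0.01$, so the tolerance $2 / s = 0.02$ used in the test leaves a comfortable margin."
   ]
  },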
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "9ef68e15-8eb8-41ed-924c-a767ef8fbee8",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "module {\n",
      "  func.func @main(%arg0: !FHE.eint<9>, %arg1: !FHE.eint<9>) -> !FHE.eint<9> {\n",
      "    %0 = \"FHE.add_eint\"(%arg0, %arg1) : (!FHE.eint<9>, !FHE.eint<9>) -> !FHE.eint<9>\n",
      "    return %0 : !FHE.eint<9>\n",
      "  }\n",
      "}\n"
     ]
    }
   ],
   "source": [
    "print(circuit.mlir)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "87123b2e-ae99-444c-bf19-95a292536342",
   "metadata": {},
   "source": [
    "Here, we see that the integers in our circuit have 9 bits: the scaled inputs are at most `max_value_for_floats * scaling_factor = 250`, which fits in 8 bits, plus 1 bit for the addition carry. Taking `scaling_factor = 1000`, for example, would give more precise results (`bounded_error` goes from 0.02 down to 0.002), but the circuit would then have 13-bit integers.\n",
    "\n",
    "As our circuit has no programmable bootstrapping (PBS), this doesn't have a big impact, but with non-linear operations (as in the following section), the impact on timing is much larger."
   ]
  },
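  {
   "cell_type": "markdown",
   "id": "addition-bit-width-sketch",
   "metadata": {},
   "source": [
    "The bit widths quoted above can be checked by hand. This is only a sketch: it assumes that the inputset gets close enough to the largest scaled values, and that the width is the smallest $n$ such that the largest value $v_{\\max}$ fits, which is consistent with the MLIR shown above. The symbols $n$ and $v_{\\max}$ are notation introduced for this sketch.\n",
    "\n",
    "$$n = \\lceil \\log_2 (v_{\\max} + 1) \\rceil$$\n",
    "\n",
    "- For `scaling_factor = 100`, the sum is at most $2 \\times 250 = 500$, so $n = \\lceil \\log_2 501 \\rceil = 9$.\n",
    "- For `scaling_factor = 1000`, the sum is at most $2 \\times 2500 = 5000$, so $n = \\lceil \\log_2 5001 \\rceil = 13$.\n",
    "\n",
    "In other words, each factor of 10 in the scaling factor costs a bit more than 3 extra bits."
   ]
  },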
  {
   "cell_type": "markdown",
   "id": "d3502aef-767b-4f96-b465-771763c1b4ea",
   "metadata": {},
   "source": [
    "## More complex case with Programmable Bootstrapping"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "a7f6e053-bc1a-45d3-aa15-9bbd63ae382a",
   "metadata": {},
   "source": [
    "Let's do the same for computing the norm of a vector. Here, the square root is a non-linear operation, so the circuit contains PBS."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "db901dcb-412c-4b52-8c1e-3d41782bc650",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximal bitwidth in the circuit: 10\n",
      "\n",
      "Showing an example\n",
      "random_sample=(0.6597925446542308, 2.2648103856817494)\n",
      "scaled_sample=(5, 18)\n",
      "scaled_result=19\n",
      "\n",
      "FHE:   2.375\n",
      "clear: 2.358960000736176\n",
      "bounded error: 0.25\n",
      "\n",
      "Compilation and test look good, with FHE execution time of about 11.16 seconds per inference\n"
     ]
    }
   ],
   "source": [
    "max_value_for_floats = 2.5\n",
    "scaling_factor = 8\n",
    "nb_test_samples = 10\n",
    "\n",
    "\n",
    "@fhe.compiler({\"x\": \"encrypted\", \"y\": \"encrypted\"})\n",
    "def norm(x, y):\n",
    "    return np.round(fhe.univariate(lambda x: np.sqrt(x))(x**2 + y**2)).astype(np.int64)\n",
    "\n",
    "\n",
    "# Compile\n",
    "inputset = [\n",
    "    (\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "    )\n",
    "    for _ in range(length_inputset)\n",
    "]\n",
    "circuit = norm.compile(inputset, show_mlir=False)\n",
    "\n",
    "print(f\"Maximal bitwidth in the circuit: {circuit.graph.maximum_integer_bit_width()}\\n\")\n",
    "\n",
    "# Check\n",
    "verbose = True\n",
    "time_begin = time()\n",
    "\n",
    "for _ in range(nb_test_samples):\n",
    "    # Take a random float input\n",
    "    random_sample = (rand() * max_value_for_floats, rand() * max_value_for_floats)\n",
    "    if verbose:\n",
    "        print(\"Showing an example\")\n",
    "        print(f\"{random_sample=}\")\n",
    "\n",
    "    # Scale it and round\n",
    "    scaled_sample = (\n",
    "        round(random_sample[0] * scaling_factor).astype(np.uint32),\n",
    "        round(random_sample[1] * scaling_factor).astype(np.uint32),\n",
    "    )\n",
    "    if verbose:\n",
    "        print(f\"{scaled_sample=}\")\n",
    "\n",
    "    # Encrypt\n",
    "    encrypted_scaled_sample = circuit.encrypt(*scaled_sample)\n",
    "\n",
    "    # Computations in FHE\n",
    "    encrypted_scaled_result = circuit.run(*encrypted_scaled_sample)\n",
    "\n",
    "    # Decrypt\n",
    "    scaled_result = circuit.decrypt(encrypted_scaled_result)\n",
    "    if verbose:\n",
    "        print(f\"{scaled_result=}\")\n",
    "\n",
    "    # Unscale\n",
    "    result = scaled_result * 1.0 / scaling_factor\n",
    "\n",
    "    bounded_error = 2.0 / scaling_factor\n",
    "    clear_result = np.sqrt(random_sample[0] ** 2 + random_sample[1] ** 2)\n",
    "\n",
    "    if verbose:\n",
    "        print()\n",
    "        print(f\"FHE:   {result}\")\n",
    "        print(f\"clear: {clear_result}\")\n",
    "        print(f\"bounded error: {bounded_error}\")\n",
    "\n",
    "    are_almost_the_same(result, clear_result, bounded_error)\n",
    "\n",
    "    verbose = False\n",
    "\n",
    "print(\n",
    "    \"\\nCompilation and test look good, with FHE execution time of about \"\n",
    "    f\"{(time() - time_begin) / nb_test_samples:.2f} seconds per inference\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "560956ce-8791-4dd8-b7c0-c9c489f22063",
   "metadata": {},
   "source": [
    "Note that this function is much slower, since it includes PBS. Furthermore, the input of the PBS can be as large as twice the square of the largest scaled input, which quickly becomes large. That's why we had to reduce the scaling factor, and hence the precision. When more precision is needed, one has to be a bit more patient."
   ]
  },
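  {
   "cell_type": "markdown",
   "id": "norm-pbs-input-size-sketch",
   "metadata": {},
   "source": [
    "To make the size of the PBS input concrete, here is the arithmetic for this circuit. It is a sketch that assumes the inputset reaches values close to the maximum; $s$ (for `scaling_factor`), $M$ (for `max_value_for_floats`), $\\tilde{x}$ and $\\tilde{y}$ (for the scaled inputs) are notation introduced only here.\n",
    "\n",
    "$$\\tilde{x}, \\tilde{y} \\le s\\,M = 8 \\times 2.5 = 20 \\quad\\Longrightarrow\\quad \\tilde{x}^2 + \\tilde{y}^2 \\le 2 \\times 20^2 = 800, \\qquad \\lceil \\log_2 801 \\rceil = 10.$$\n",
    "\n",
    "This agrees with the maximal bit width of 10 printed above. Keeping `scaling_factor = 100` as in the addition example would instead give\n",
    "\n",
    "$$\\tilde{x}^2 + \\tilde{y}^2 \\le 2 \\times 250^2 = 125000, \\qquad \\lceil \\log_2 125001 \\rceil = 17,$$\n",
    "\n",
    "that is, a square root evaluated on 17-bit values, which is far more costly."
   ]
  },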
  {
   "cell_type": "markdown",
   "id": "85995459-8036-4b18-92f0-261d524b1a6e",
   "metadata": {},
   "source": [
    "## Final example"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "afc5c7eb-db95-49c7-accc-5897e53db414",
   "metadata": {},
   "source": [
    "We finish with a more complex example, where one needs to understand which computations are performed and cancel some scaling factors along the way, to ensure correct results.\n",
    "\n",
    "Let's suppose we want to convert `f(x, y) = (x**2 + 4 * y) / 1.33` to FHE. We scale `x` and `y` with a scaling factor `scaling_factor` before encryption. Then, in `x**2`, the scaling factors multiply (so the scaling factor becomes `scaling_factor**2`), while in `4 * y` the scaling factor remains `scaling_factor`. To keep the addition homogeneous, we have two possibilities:\n",
    "- multiply `4 * y` by `scaling_factor` in the FHE circuit, and at the end, unscale by dividing by `scaling_factor**2`\n",
    "- or, divide `x**2` by `scaling_factor` in the FHE circuit, and at the end, unscale by dividing by `scaling_factor`.\n",
    "\n",
    "We have chosen the second approach, since it avoids enlarging the integers too much, which makes the FHE execution more efficient.\n",
    "\n",
    "The error bound is more complicated to compute, so we use an estimate."
   ]
  },
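  {
   "cell_type": "markdown",
   "id": "scale-tracking-sketch",
   "metadata": {},
   "source": [
    "The second approach can be followed step by step by tracking which power of the scaling factor each intermediate value carries. In this sketch, $s$ stands for `scaling_factor`, and $\\tilde{x} \\approx s\\,x$, $\\tilde{y} \\approx s\\,y$ are the scaled, rounded inputs; this notation is introduced only for the illustration, and the approximations ignore rounding.\n",
    "\n",
    "- Squaring: $\\tilde{x}^2 \\approx s^2 x^2$, which carries a factor $s^2$.\n",
    "- Integer division by $s$: $\\lfloor \\tilde{x}^2 / s \\rfloor \\approx s\\,x^2$, which is back to a factor $s$.\n",
    "- Linear term: $4\\,\\tilde{y} \\approx 4\\,s\\,y$, which also carries a factor $s$.\n",
    "- Sum: $\\lfloor \\tilde{x}^2 / s \\rfloor + 4\\,\\tilde{y} \\approx s\\,(x^2 + 4y)$.\n",
    "- Division by the constant: $\\mathrm{round}\\big((\\lfloor \\tilde{x}^2 / s \\rfloor + 4\\,\\tilde{y}) / 1.33\\big) \\approx s\\,f(x, y)$.\n",
    "\n",
    "A single division by $s$ after decryption then recovers an approximation of $f(x, y)$."
   ]
  },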
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "e3fdb20c-3e28-489e-8ede-640331674b7a",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Maximal bitwidth in the circuit: 7\n",
      "\n",
      "Showing an example\n",
      "random_sample=(0.43374356260757363, 1.6491834206477578)\n",
      "scaled_sample=(3, 13)\n",
      "scaled_result=40\n",
      "\n",
      "FHE:   5.0\n",
      "clear: 5.10140388022146\n",
      "bounded error: 0.375\n",
      "\n",
      "Compilation and test look good, with FHE execution time of about 0.82 seconds per inference\n"
     ]
    }
   ],
   "source": [
    "max_value_for_floats = 2\n",
    "scaling_factor = 8\n",
    "nb_test_samples = 8\n",
    "\n",
    "\n",
    "def special_function_in_clear(x, y):\n",
    "    u = x**2\n",
    "    v = u + 4 * y\n",
    "    return v / 1.33\n",
    "\n",
    "\n",
    "@fhe.compiler({\"x\": \"encrypted\", \"y\": \"encrypted\"})\n",
    "def special_function(x, y):\n",
    "    u = fhe.univariate(lambda x: x**2 // scaling_factor)(x)\n",
    "    v = u + 4 * y\n",
    "    return np.round(fhe.univariate(lambda x: x / 1.33)(v)).astype(np.int64)\n",
    "\n",
    "\n",
    "# Compile\n",
    "inputset = [\n",
    "    (\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "        round(scaling_factor * rand() * max_value_for_floats).astype(np.uint32),\n",
    "    )\n",
    "    for _ in range(length_inputset)\n",
    "]\n",
    "circuit = special_function.compile(inputset, show_mlir=False)\n",
    "\n",
    "print(f\"Maximal bitwidth in the circuit: {circuit.graph.maximum_integer_bit_width()}\\n\")\n",
    "\n",
    "# Check\n",
    "verbose = True\n",
    "time_begin = time()\n",
    "\n",
    "for _ in range(nb_test_samples):\n",
    "    # Take a random float input\n",
    "    random_sample = (rand() * max_value_for_floats, rand() * max_value_for_floats)\n",
    "    if verbose:\n",
    "        print(\"Showing an example\")\n",
    "        print(f\"{random_sample=}\")\n",
    "\n",
    "    # Scale it and round\n",
    "    scaled_sample = (\n",
    "        round(random_sample[0] * scaling_factor).astype(np.uint32),\n",
    "        round(random_sample[1] * scaling_factor).astype(np.uint32),\n",
    "    )\n",
    "    if verbose:\n",
    "        print(f\"{scaled_sample=}\")\n",
    "\n",
    "    # Encrypt\n",
    "    encrypted_scaled_sample = circuit.encrypt(*scaled_sample)\n",
    "\n",
    "    # Computations in FHE\n",
    "    encrypted_scaled_result = circuit.run(*encrypted_scaled_sample)\n",
    "\n",
    "    # Decrypt\n",
    "    scaled_result = circuit.decrypt(encrypted_scaled_result)\n",
    "    if verbose:\n",
    "        print(f\"{scaled_result=}\")\n",
    "\n",
    "    # Unscale\n",
    "    result = scaled_result * 1.0 / scaling_factor\n",
    "\n",
    "    bounded_error = 3.0 / scaling_factor\n",
    "    clear_result = special_function_in_clear(*random_sample)\n",
    "\n",
    "    if verbose:\n",
    "        print()\n",
    "        print(f\"FHE:   {result}\")\n",
    "        print(f\"clear: {clear_result}\")\n",
    "        print(f\"bounded error: {bounded_error}\")\n",
    "\n",
    "    are_almost_the_same(result, clear_result, bounded_error)\n",
    "\n",
    "    verbose = False\n",
    "\n",
    "print(\n",
    "    \"\\nCompilation and test look good, with FHE execution time of about \"\n",
    "    f\"{(time() - time_begin) / nb_test_samples:.2f} seconds per inference\"\n",
    ")"
   ]
  },
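  {
   "cell_type": "markdown",
   "id": "final-example-trace-sketch",
   "metadata": {},
   "source": [
    "As a check, the example printed above can be traced by hand. The values below are taken from that printed output (they will differ on another run, since the samples are random), with `scaling_factor = 8` and `scaled_sample = (3, 13)`:\n",
    "\n",
    "$$u = \\lfloor 3^2 / 8 \\rfloor = 1, \\qquad v = u + 4 \\times 13 = 53, \\qquad \\mathrm{round}(53 / 1.33) = \\mathrm{round}(39.85) = 40, \\qquad 40 / 8 = 5.0.$$\n",
    "\n",
    "This matches `scaled_result=40` and the FHE result `5.0`. The clear result is about $5.10$, so the gap of about $0.10$ is within the estimated tolerance of $3 / 8 = 0.375$."
   ]
  },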
  {
   "cell_type": "markdown",
   "id": "3b228ca3-017f-4502-9b9d-ff17d9cd9171",
   "metadata": {},
   "source": [
    "## Conclusion\n",
    "\n",
    "Floats are not natively supported in TFHE, and thus not in Concrete. However, we have shown in this tutorial that by scaling floats to integers, it is entirely possible to perform these computations in FHE. Finally, such techniques, called quantization techniques, are directly integrated into Concrete ML for machine learning use cases. More information is available in the Concrete ML documentation, see https://docs.zama.ai/concrete-ml/explanations/quantization."
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.7 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
